-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More fixes for 2023.12 support #166
Conversation
- Ensure the arrays that are created are created on the same device as x. (fixes data-apis#177) - Make clip() work with dask.array. The workaround avoid uint64 -> float64 promotion does not work here. (fixes data-apis#176) - Fix loss of precision when clipping a float64 tensor with torch due to the scalar being converted to a float32 tensor.
I'm not sure if all the details here are correct. See data-apis#127 (comment).
Some of these things have to be inspected manually, and I'm not completely certain everything here is correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm hoping someone with more knowledge of pytorch can review this. I've made various assumptions here, which I'll try to outline in comments below. If anyone can confirm whether those assumptions are valid, that would be helpful.
""" | ||
return { | ||
"boolean indexing": True, | ||
"data-dependent shapes": True, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm assuming boolean indexing and data-dependent shapes (i.e., functions like unique
) always work in pytorch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you clarify what you mean by always work? unique
and nonzero
will generally work in PyTorch the same way they do for NumPy, considering normal eager execution.
With the compiler, maybe depending on the opinions chosen, is a graph break considered working?
Complex dtypes will not work, but not due to issues with data dependent shapes, it is because our implementation will unconditionally sort the input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hadn't even considered the torch compiler. But I think as long as it functions without an error, that should be considered working (if that's not the case, then it might actually be worth flipping this flag in a compiler context, assuming that's easy).
These flags exist for libraries like dask or JAX that straight up don't allow these types of operations.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case I would say you are probably correct the way you have it. Those are intended to be supported and lack of support/conditional support in the compile context is considered a deficiency.
'cpu' | ||
|
||
""" | ||
return torch.device("cpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm assuming the default device in pytorch is always "cpu" (note there is currently an unfortunate distinction between the "default device" and "current device" in the standard. See data-apis/array-api#835). By "default device", I mean the device that is used by default when pytorch is first started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is correct.
'indexing': torch.int64} | ||
|
||
""" | ||
default_floating = torch.get_default_dtype() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I'm understanding the docs correctly, this function will always give the default floating-point dtype (i.e., what is used by default for something like torch.ones()
with no dtype argument).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is true. However that "default" in torch is not static. It can be changed by the user, so that would be the "current default" and the "default default" would be torch.float32
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a note to myself: I checked and torch does correctly fail if you set the default dtype to float64 and try to create a tensor on a device that doesn't support float64:
>>> torch.set_default_dtype(torch.float64)
>>> torch.asarray(0., device='mps')
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
(unlike certain other libraries that silently map float64 back to float32)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so that would be the "current default" and the "default default" would be torch.float32
I'm not sure which this should use then. I think it should be the "current default", but the meaning of "default" is ambiguous. I mentioned this at data-apis/array-api#835 (comment)
|
||
""" | ||
default_floating = torch.get_default_dtype() | ||
default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docs explicitly state that the default complex dtype always matches the floating dtype. https://pytorch.org/docs/stable/generated/torch.set_default_dtype.html
array_api_compat/torch/_info.py
Outdated
""" | ||
default_floating = torch.get_default_dtype() | ||
default_complex = torch.complex64 if default_floating == torch.float32 else torch.complex128 | ||
default_integral = torch.asarray(0, device=device).dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to access this that doesn't require creating a tensor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can also just hard code it, we don't have a default for integers internally. If this were to change it would break bc and would need to be updated everywhere an integer tensor is produced, manually.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mean it's always int64
? I wasn't sure if this would be different for certain devices.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't have a concept of the default integer type. It would always need to be provided or the type is deduced from the argument that is what happens in this case, python ints can be larger than int32 so we use int64.
I mean that I do not know of any mechanism for changing the behavior you see here. Our deduction rules for creation from a non array type can be seen here. If the device does not support int64 I would expect you to see an error like you do with MPS and float64 (which I did not know would happen, I expect they had to do some work to make sure that error gets raised.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that that layer of code is going to run "above" the dispatch to device specific logic (as in closer to python), so there is no way for the device to influence how that type is deduced, the device will be involved when the underlying tensor object is constructed where it will only see that a dtype
argument has been provided.
"real floating": default_floating, | ||
"complex floating": default_complex, | ||
"integral": default_integral, | ||
"indexing": default_integral, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm assuming the default indexing type is always the same as the default integer dtype (the default indexing type is a concept in the array API for functions that return indexing arrays. For example, nonzero
should return an array with the default indexing type https://data-apis.org/array-api/latest/API_specification/generated/array_api.nonzero.html)
array_api_compat/torch/_info.py
Outdated
uint8 = getattr(torch, "uint8", None) | ||
uint16 = getattr(torch, "uint16", None) | ||
uint32 = getattr(torch, "uint32", None) | ||
uint64 = getattr(torch, "uint64", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these the only dtypes that can be undefined (I know newer torch versions have them, but I want to make sure older versions work here too).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I know the support for some of these is limited. Would it be more correct to always omit them, even when they are technically defined here? They aren't really fully supported from the point of view of the array API standard.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would say that depends on the purpose of the declaration. If it is to list the definitions for data types provided by the library leave them. If the point is to declare the data types supported by the array api then drop them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I should probably remove them then. Too many array API things don't actually work with them, and this API is supposed to be a better way to check that than hasattr(xp, 'uint64')
.
del res[k] | ||
continue | ||
try: | ||
torch.empty((0,), dtype=v, device=device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the best way to test if a dtype is supported on a given device?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What defines supported?
If dtype and device are valid this should always work, storage is untyped so the dtype is only used to compute the number of bytes needed.
Any given operator may or may not correctly dispatch for the underlying dtype. If the type is not explicitly handed for a given operator it will throw.
return res | ||
raise ValueError(f"unsupported kind: {kind!r}") | ||
|
||
@cache |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm assuming it's safe to cache the output of this function (it's potentially expensive since it constructs tensors on the given device to test if a dtype is supported).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is currently no way to define a new dtype than adding an entry to a enum in the source code.
del res[k] | ||
return res | ||
|
||
@cache |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm assuming the set of legal devices never changes at runtime and can be cached.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use hooks to register a custom device backend dynamically at runtime. I am unsure if this will simply add a new accepted device type string, or if the privateuseone
string is overwritten by it.
# currently supported devices. To do this, we first parse the error | ||
# message of torch.device to get the list of all possible types of | ||
# device: | ||
try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the big one. This is code is based on some discussions with @pearu. To get the list of possible torch devices, we first parse an error message, then check which of those devices actually work. If there is any better way of doing this, please let me know. I would definitely prefer if this functionality were built-in to pytorch (ditto for the other functions here too).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is unfortunate. I think there might be a less expensive, but more ugly way to do this. Would adding something to surface this info even help, or would you still need to have something like this to support older versions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding something would definitely be helpful. Note that ideally, pytorch will implement this exact API here, since it's part of the array API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed that pytorch sort of has APIs to do this better, e.g., https://pytorch.org/docs/stable/generated/torch.mps.device_count.html#torch.mps.device_count, but they are not consistent across all device types, and I didn't want to hard-code all the possible device types here since torch seems to support a lot of them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I think that API should exist for optional devices. And you can find the module programmatically with torch.get_device_module(<name>)
. The always available devices are of course special, but I can look into handling this on the pytorch side.
'int64': cupy.int64} | ||
|
||
""" | ||
# TODO: Does this depend on device? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@leofang is it possible for cupy to not support some of the array API dtypes depending on a given device?
See the discussion at data-apis#166 (comment)
So a couple of small things are still missing here #127. Most notable is |
#127
Fixes #152